{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "## 前言\n",
    "接着上一小节，我们对Huggingface开源代码库中的Bert模型进行了深入学习，这一节我们对如何应用BERT进行详细的讲解。\n",
    "\n",
    "涉及到的jupyter可以在[代码库：篇章3-编写一个Transformer模型：BERT，下载](https://github.com/datawhalechina/learn-nlp-with-transformers/tree/main/docs/%E7%AF%87%E7%AB%A03-%E7%BC%96%E5%86%99%E4%B8%80%E4%B8%AATransformer%E6%A8%A1%E5%9E%8B%EF%BC%9ABERT)\n",
    "\n",
    "本文基于 Transformers 版本 4.4.2（2021 年 3 月 19 日发布）项目中，pytorch 版的 BERT 相关代码，从代码结构、具体实现与原理，以及使用的角度进行分析，包含以下内容：\n",
    "\n",
    "3. BERT-based Models应用模型\n",
    "4. BERT训练和优化\n",
    "5. Bert解决NLP任务\n",
    "  - BertForSequenceClassification\n",
    "  - BertForMultiChoice\n",
    "  - BertForTokenClassification\n",
    "  - BertForQuestionAnswering\n",
    "6. BERT训练与优化\n",
    "7. Pre-Training\n",
    "  - Fine-Tuning\n",
    "  - AdamW\n",
    "  - Warmup\n",
    "\n",
    "## 3-BERT-based Models\n",
    "基于 BERT 的模型都写在/models/bert/modeling_bert.py里面，包括 BERT 预训练模型和 BERT 分类等模型。\n",
    "\n",
    "首先，以下所有的模型都是基于`BertPreTrainedModel`这一抽象基类的，而后者则基于一个更大的基类`PreTrainedModel`。这里我们关注`BertPreTrainedModel`的功能：\n",
    "\n",
    "用于初始化模型权重，同时维护继承自`PreTrainedModel`的一些标记身份或者加载模型时的类变量。\n",
    "下面，首先从预训练模型开始分析。"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "*** \n",
    "### 3.1 BertForPreTraining\n",
    "\n",
    "众所周知，BERT 预训练任务包括两个：\n",
    "\n",
    "- Masked Language Model（MLM）：在句子中随机用`[MASK]`替换一部分单词，然后将句子传入 BERT 中编码每一个单词的信息，最终用`[MASK]`的编码信息预测该位置的正确单词，这一任务旨在训练模型根据上下文理解单词的意思；\n",
    "- Next Sentence Prediction（NSP）：将句子对 A 和 B 输入 BERT，使用`[CLS]`的编码信息进行预测 B 是否 A 的下一句，这一任务旨在训练模型理解预测句子间的关系。\n",
    "\n",
    "\n",
    "![图Bert预训练](./pictures/3-3-bert-lm.png) 图Bert预训练\n",
    "\n",
    "而对应到代码中，这一融合两个任务的模型就是BertForPreTraining，其中包含两个组件：\n",
    "```\n",
    "class BertForPreTraining(BertPreTrainedModel):\n",
    "    def __init__(self, config):\n",
    "        super().__init__(config)\n",
    "\n",
    "        self.bert = BertModel(config)\n",
    "        self.cls = BertPreTrainingHeads(config)\n",
    "\n",
    "        self.init_weights()\n",
    "    # ...\n",
    "```\n",
    "这里的BertModel在上一章节中已经详细介绍了（注意，这里设置的是默认`add_pooling_layer=True`，即会提取`[CLS]`对应的输出用于 NSP 任务），而`BertPreTrainingHeads`则是负责两个任务的预测模块：\n",
    "```\n",
    "class BertPreTrainingHeads(nn.Module):\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        self.predictions = BertLMPredictionHead(config)\n",
    "        self.seq_relationship = nn.Linear(config.hidden_size, 2)\n",
    "\n",
    "    def forward(self, sequence_output, pooled_output):\n",
    "        prediction_scores = self.predictions(sequence_output)\n",
    "        seq_relationship_score = self.seq_relationship(pooled_output)\n",
    "        return prediction_scores, seq_relationship_score \n",
    "```\n",
    "又是一层封装：`BertPreTrainingHeads`包裹了`BertLMPredictionHead` 和一个代表 NSP 任务的线性层。这里不把 NSP 对应的任务也封装一个`BertXXXPredictionHead`。\n",
    "\n",
    "**其实是有封装这个类的，不过它叫做BertOnlyNSPHead，在这里用不上**\n",
    "\n",
    "继续下探`BertPreTrainingHeads`：\n",
    "```\n",
    "class BertLMPredictionHead(nn.Module):\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        self.transform = BertPredictionHeadTransform(config)\n",
    "\n",
    "        # The output weights are the same as the input embeddings, but there is\n",
    "        # an output-only bias for each token.\n",
    "        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n",
    "\n",
    "        self.bias = nn.Parameter(torch.zeros(config.vocab_size))\n",
    "\n",
    "        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`\n",
    "        self.decoder.bias = self.bias\n",
    "\n",
    "    def forward(self, hidden_states):\n",
    "        hidden_states = self.transform(hidden_states)\n",
    "        hidden_states = self.decoder(hidden_states)\n",
    "        return hidden_states\n",
    "```\n",
    "\n",
    "这个类用于预测`[MASK]`位置的输出在每个词作为类别的分类输出，注意到：\n",
    "\n",
    "- 该类重新初始化了一个全 0 向量作为预测权重的 bias；\n",
    "- 该类的输出形状为[batch_size, seq_length, vocab_size]，即预测每个句子每个词是什么类别的概率值（注意这里没有做 softmax）；\n",
    "- 又一个封装的类：BertPredictionHeadTransform，用来完成一些线性变换：\n",
    "```\n",
    "class BertPredictionHeadTransform(nn.Module):\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        self.dense = nn.Linear(config.hidden_size, config.hidden_size)\n",
    "        if isinstance(config.hidden_act, str):\n",
    "            self.transform_act_fn = ACT2FN[config.hidden_act]\n",
    "        else:\n",
    "            self.transform_act_fn = config.hidden_act\n",
    "        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n",
    "\n",
    "    def forward(self, hidden_states):\n",
    "        hidden_states = self.dense(hidden_states)\n",
    "        hidden_states = self.transform_act_fn(hidden_states)\n",
    "        hidden_states = self.LayerNorm(hidden_states)\n",
    "        return hidden_states\n",
    "```\n",
    "\n",
    "回到`BertForPreTraining`，继续看两块 `loss` 是怎么处理的。它的前向传播和BertModel的有所不同，多了`labels`和`next_sentence_label` 两个输入：\n",
    "\n",
    "- labels：形状为[batch_size, seq_length] ，代表 MLM 任务的标签，注意这里对于原本未被遮盖的词设置为 -100，被遮盖词才会有它们对应的 id，和任务设置是反过来的。\n",
    "\n",
    "  - 例如，原始句子是I want to [MASK] an apple，这里我把单词eat给遮住了输入模型，对应的label设置为[-100, -100, -100, 【eat对应的id】, -100, -100]；\n",
    "  - 为什么要设置为 -100 而不是其他数？因为torch.nn.CrossEntropyLoss默认的ignore_index=-100，也就是说对于标签为 100 的类别输入不会计算 loss。\n",
    "\n",
    "- next_sentence_label：这一个输入很简单，就是 0 和 1 的二分类标签。\n",
    "\n",
    "```\n",
    "# ...\n",
    "    def forward(\n",
    "        self,\n",
    "        input_ids=None,\n",
    "        attention_mask=None,\n",
    "        token_type_ids=None,\n",
    "        position_ids=None,\n",
    "        head_mask=None,\n",
    "        inputs_embeds=None,\n",
    "        labels=None,\n",
    "        next_sentence_label=None,\n",
    "        output_attentions=None,\n",
    "        output_hidden_states=None,\n",
    "        return_dict=None,\n",
    "    ): ...\n",
    "```\n",
    "\n",
    "接下来两部分 loss 的组合：\n",
    "```\n",
    " # ...\n",
    "        total_loss = None\n",
    "        if labels is not None and next_sentence_label is not None:\n",
    "            loss_fct = CrossEntropyLoss()\n",
    "            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))\n",
    "            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))\n",
    "            total_loss = masked_lm_loss + next_sentence_loss\n",
    "        # ...\n",
    "```\n",
    "\n",
    "直接相加，就是这么单纯的策略。\n",
    "当然，这份代码里面也包含了对于只想对单个目标进行预训练的 BERT 模型（具体细节不作展开）：\n",
    "- BertForMaskedLM：只进行 MLM 任务的预训练；\n",
    "  - 基于BertOnlyMLMHead，而后者也是对BertLMPredictionHead的另一层封装；\n",
    "- BertLMHeadModel：这个和上一个的区别在于，这一模型是作为 decoder 运行的版本；\n",
    "  - 同样基于BertOnlyMLMHead；\n",
    "- BertForNextSentencePrediction：只进行 NSP 任务的预训练。\n",
    "  - 基于BertOnlyNSPHead，内容就是一个线性层。"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "source": [
    "_CHECKPOINT_FOR_DOC = \"bert-base-uncased\"\n",
    "_CONFIG_FOR_DOC = \"BertConfig\"\n",
    "_TOKENIZER_FOR_DOC = \"BertTokenizer\"\n",
    "from transformers.models.bert.modeling_bert import *\n",
    "from transformers.models.bert.configuration_bert import *\n",
    "class BertForPreTraining(BertPreTrainedModel):\n",
    "    def __init__(self, config):\n",
    "        super().__init__(config)\n",
    "\n",
    "        self.bert = BertModel(config)\n",
    "        self.cls = BertPreTrainingHeads(config)\n",
    "\n",
    "        self.init_weights()\n",
    "\n",
    "    def get_output_embeddings(self):\n",
    "        return self.cls.predictions.decoder\n",
    "\n",
    "    def set_output_embeddings(self, new_embeddings):\n",
    "        self.cls.predictions.decoder = new_embeddings\n",
    "\n",
    "    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format(\"batch_size, sequence_length\"))\n",
    "    @replace_return_docstrings(output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)\n",
    "    def forward(\n",
    "        self,\n",
    "        input_ids=None,\n",
    "        attention_mask=None,\n",
    "        token_type_ids=None,\n",
    "        position_ids=None,\n",
    "        head_mask=None,\n",
    "        inputs_embeds=None,\n",
    "        labels=None,\n",
    "        next_sentence_label=None,\n",
    "        output_attentions=None,\n",
    "        output_hidden_states=None,\n",
    "        return_dict=None,\n",
    "    ):\n",
    "        r\"\"\"\n",
    "        labels (:obj:`torch.LongTensor` of shape ``(batch_size, sequence_length)``, `optional`):\n",
    "            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,\n",
    "            config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored\n",
    "            (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``\n",
    "        next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`):\n",
    "            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair\n",
    "            (see :obj:`input_ids` docstring) Indices should be in ``[0, 1]``:\n",
    "            - 0 indicates sequence B is a continuation of sequence A,\n",
    "            - 1 indicates sequence B is a random sequence.\n",
    "        kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):\n",
    "            Used to hide legacy arguments that have been deprecated.\n",
    "        Returns:\n",
    "        Example::\n",
    "from transformers import BertTokenizer, BertForPreTraining\n",
    "import torch\n",
    "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
    "model = BertForPreTraining.from_pretrained('bert-base-uncased')\n",
    "inputs = tokenizer(\"Hello, my dog is cute\", return_tensors=\"pt\")\n",
    "outputs = model(**inputs)\n",
    "prediction_logits = outputs.prediction_logits\n",
    "seq_relationship_logits = outputs.seq_relationship_logits\n",
    "        \"\"\"\n",
    "        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n",
    "\n",
    "        outputs = self.bert(\n",
    "            input_ids,\n",
    "            attention_mask=attention_mask,\n",
    "            token_type_ids=token_type_ids,\n",
    "            position_ids=position_ids,\n",
    "            head_mask=head_mask,\n",
    "            inputs_embeds=inputs_embeds,\n",
    "            output_attentions=output_attentions,\n",
    "            output_hidden_states=output_hidden_states,\n",
    "            return_dict=return_dict,\n",
    "        )\n",
    "\n",
    "        sequence_output, pooled_output = outputs[:2]\n",
    "        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)\n",
    "\n",
    "        total_loss = None\n",
    "        if labels is not None and next_sentence_label is not None:\n",
    "            loss_fct = CrossEntropyLoss()\n",
    "            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))\n",
    "            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))\n",
    "            total_loss = masked_lm_loss + next_sentence_loss\n",
    "\n",
    "        if not return_dict:\n",
    "            output = (prediction_scores, seq_relationship_score) + outputs[2:]\n",
    "            return ((total_loss,) + output) if total_loss is not None else output\n",
    "\n",
    "        return BertForPreTrainingOutput(\n",
    "            loss=total_loss,\n",
    "            prediction_logits=prediction_scores,\n",
    "            seq_relationship_logits=seq_relationship_score,\n",
    "            hidden_states=outputs.hidden_states,\n",
    "            attentions=outputs.attentions,\n",
    "        )\n",
    "\n",
    "from transformers import BertTokenizer, BertForPreTraining\n",
    "import torch\n",
    "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
    "model = BertForPreTraining.from_pretrained('bert-base-uncased')\n",
    "inputs = tokenizer(\"Hello, my dog is cute\", return_tensors=\"pt\")\n",
    "outputs = model(**inputs)\n",
    "prediction_logits = outputs.prediction_logits\n",
    "seq_relationship_logits = outputs.seq_relationship_logits"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "Some weights of BertForPreTraining were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['cls.predictions.decoder.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "source": [
    "@add_start_docstrings(\n",
    "    \"\"\"Bert Model with a `language modeling` head on top for CLM fine-tuning. \"\"\", BERT_START_DOCSTRING\n",
    ")\n",
    "class BertLMHeadModel(BertPreTrainedModel):\n",
    "\n",
    "    _keys_to_ignore_on_load_unexpected = [r\"pooler\"]\n",
    "    _keys_to_ignore_on_load_missing = [r\"position_ids\", r\"predictions.decoder.bias\"]\n",
    "\n",
    "    def __init__(self, config):\n",
    "        super().__init__(config)\n",
    "\n",
    "        if not config.is_decoder:\n",
    "            logger.warning(\"If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`\")\n",
    "\n",
    "        self.bert = BertModel(config, add_pooling_layer=False)\n",
    "        self.cls = BertOnlyMLMHead(config)\n",
    "\n",
    "        self.init_weights()\n",
    "\n",
    "    def get_output_embeddings(self):\n",
    "        return self.cls.predictions.decoder\n",
    "\n",
    "    def set_output_embeddings(self, new_embeddings):\n",
    "        self.cls.predictions.decoder = new_embeddings\n",
    "\n",
    "    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format(\"batch_size, sequence_length\"))\n",
    "    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)\n",
    "    def forward(\n",
    "        self,\n",
    "        input_ids=None,\n",
    "        attention_mask=None,\n",
    "        token_type_ids=None,\n",
    "        position_ids=None,\n",
    "        head_mask=None,\n",
    "        inputs_embeds=None,\n",
    "        encoder_hidden_states=None,\n",
    "        encoder_attention_mask=None,\n",
    "        labels=None,\n",
    "        past_key_values=None,\n",
    "        use_cache=None,\n",
    "        output_attentions=None,\n",
    "        output_hidden_states=None,\n",
    "        return_dict=None,\n",
    "    ):\n",
    "        r\"\"\"\n",
    "        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):\n",
    "            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if\n",
    "            the model is configured as a decoder.\n",
    "        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):\n",
    "            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in\n",
    "            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:\n",
    "            - 1 for tokens that are **not masked**,\n",
    "            - 0 for tokens that are **masked**.\n",
    "        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):\n",
    "            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in\n",
    "            ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are\n",
    "            ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``\n",
    "        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):\n",
    "            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.\n",
    "            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`\n",
    "            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`\n",
    "            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.\n",
    "        use_cache (:obj:`bool`, `optional`):\n",
    "            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up\n",
    "            decoding (see :obj:`past_key_values`).\n",
    "        Returns:\n",
    "        Example::\n",
    "            from transformers import BertTokenizer, BertLMHeadModel, BertConfig\n",
    "            import torch\n",
    "            tokenizer = BertTokenizer.from_pretrained('bert-base-cased')\n",
    "            config = BertConfig.from_pretrained(\"bert-base-cased\")\n",
    "            config.is_decoder = True\n",
    "            model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)\n",
    "            inputs = tokenizer(\"Hello, my dog is cute\", return_tensors=\"pt\")\n",
    "            outputs = model(**inputs)\n",
    "            prediction_logits = outputs.logits\n",
    "        \"\"\"\n",
    "        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n",
    "        if labels is not None:\n",
    "            use_cache = False\n",
    "\n",
    "        outputs = self.bert(\n",
    "            input_ids,\n",
    "            attention_mask=attention_mask,\n",
    "            token_type_ids=token_type_ids,\n",
    "            position_ids=position_ids,\n",
    "            head_mask=head_mask,\n",
    "            inputs_embeds=inputs_embeds,\n",
    "            encoder_hidden_states=encoder_hidden_states,\n",
    "            encoder_attention_mask=encoder_attention_mask,\n",
    "            past_key_values=past_key_values,\n",
    "            use_cache=use_cache,\n",
    "            output_attentions=output_attentions,\n",
    "            output_hidden_states=output_hidden_states,\n",
    "            return_dict=return_dict,\n",
    "        )\n",
    "\n",
    "        sequence_output = outputs[0]\n",
    "        prediction_scores = self.cls(sequence_output)\n",
    "\n",
    "        lm_loss = None\n",
    "        if labels is not None:\n",
    "            # we are doing next-token prediction; shift prediction scores and input ids by one\n",
    "            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()\n",
    "            labels = labels[:, 1:].contiguous()\n",
    "            loss_fct = CrossEntropyLoss()\n",
    "            lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))\n",
    "\n",
    "        if not return_dict:\n",
    "            output = (prediction_scores,) + outputs[2:]\n",
    "            return ((lm_loss,) + output) if lm_loss is not None else output\n",
    "\n",
    "        return CausalLMOutputWithCrossAttentions(\n",
    "            loss=lm_loss,\n",
    "            logits=prediction_scores,\n",
    "            past_key_values=outputs.past_key_values,\n",
    "            hidden_states=outputs.hidden_states,\n",
    "            attentions=outputs.attentions,\n",
    "            cross_attentions=outputs.cross_attentions,\n",
    "        )\n",
    "\n",
    "    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):\n",
    "        input_shape = input_ids.shape\n",
    "        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly\n",
    "        if attention_mask is None:\n",
    "            attention_mask = input_ids.new_ones(input_shape)\n",
    "\n",
    "        # cut decoder_input_ids if past is used\n",
    "        if past is not None:\n",
    "            input_ids = input_ids[:, -1:]\n",
    "\n",
    "        return {\"input_ids\": input_ids, \"attention_mask\": attention_mask, \"past_key_values\": past}\n",
    "\n",
    "    def _reorder_cache(self, past, beam_idx):\n",
    "        reordered_past = ()\n",
    "        for layer_past in past:\n",
    "            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)\n",
    "        return reordered_past\n",
    "\n",
    "from transformers import BertTokenizer, BertLMHeadModel, BertConfig\n",
    "import torch\n",
    "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
    "config = BertConfig.from_pretrained(\"bert-base-uncased\")\n",
    "config.is_decoder = True\n",
    "model = BertLMHeadModel.from_pretrained('bert-base-uncased', config=config)\n",
    "inputs = tokenizer(\"Hello, my dog is cute\", return_tensors=\"pt\")\n",
    "outputs = model(**inputs)\n",
    "prediction_logits = outputs.logits"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertLMHeadModel: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']\n",
      "- This IS expected if you are initializing BertLMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertLMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    }
   ],
   "metadata": {}
  },
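  {
   "cell_type": "markdown",
   "source": [
    "The loss in `BertLMHeadModel.forward` above drops the last position of the logits and the first position of the labels, so that the output at position t is scored against the token at position t + 1. Below is a minimal sketch of that shift on made-up toy tensors (random stand-ins, not real model outputs):\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch.nn import CrossEntropyLoss\n",
    "\n",
    "torch.manual_seed(0)\n",
    "vocab_size = 10\n",
    "input_ids = torch.tensor([[2, 5, 7, 3]])           # made-up token ids, shape (1, 4)\n",
    "prediction_scores = torch.randn(1, 4, vocab_size)  # stand-in for the model logits\n",
    "\n",
    "# Position t predicts token t + 1: drop the last logit and the first label.\n",
    "shifted_scores = prediction_scores[:, :-1, :].contiguous()\n",
    "shifted_labels = input_ids[:, 1:].contiguous()\n",
    "print(shifted_labels.tolist())  # [[5, 7, 3]]\n",
    "print(shifted_scores.shape)     # torch.Size([1, 3, 10])\n",
    "\n",
    "loss_fct = CrossEntropyLoss()\n",
    "lm_loss = loss_fct(shifted_scores.view(-1, vocab_size), shifted_labels.view(-1))\n",
    "print(lm_loss.item())\n",
    "```"
   ],
   "metadata": {}
  },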
  {
   "cell_type": "code",
   "execution_count": 6,
   "source": [
    "class BertForNextSentencePrediction(BertPreTrainedModel):\n",
    "    def __init__(self, config):\n",
    "        super().__init__(config)\n",
    "\n",
    "        self.bert = BertModel(config)\n",
    "        self.cls = BertOnlyNSPHead(config)\n",
    "\n",
    "        self.init_weights()\n",
    "\n",
    "    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format(\"batch_size, sequence_length\"))\n",
    "    @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)\n",
    "    def forward(\n",
    "        self,\n",
    "        input_ids=None,\n",
    "        attention_mask=None,\n",
    "        token_type_ids=None,\n",
    "        position_ids=None,\n",
    "        head_mask=None,\n",
    "        inputs_embeds=None,\n",
    "        labels=None,\n",
    "        output_attentions=None,\n",
    "        output_hidden_states=None,\n",
    "        return_dict=None,\n",
    "        **kwargs,\n",
    "    ):\n",
    "        r\"\"\"\n",
    "        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):\n",
    "            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair\n",
    "            (see ``input_ids`` docstring). Indices should be in ``[0, 1]``:\n",
    "            - 0 indicates sequence B is a continuation of sequence A,\n",
    "            - 1 indicates sequence B is a random sequence.\n",
    "        Returns:\n",
    "        Example::\n",
    "from transformers import BertTokenizer, BertForNextSentencePrediction\n",
    "import torch\n",
    "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
    "model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')\n",
    "prompt = \"In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced.\"\n",
    "next_sentence = \"The sky is blue due to the shorter wavelength of blue light.\"\n",
    "encoding = tokenizer(prompt, next_sentence, return_tensors='pt')\n",
    "outputs = model(**encoding, labels=torch.LongTensor([1]))\n",
    "logits = outputs.logits\n",
    "assert logits[0, 0] < logits[0, 1] # next sentence was random\n",
    "        \"\"\"\n",
    "\n",
    "        if \"next_sentence_label\" in kwargs:\n",
    "            warnings.warn(\n",
    "                \"The `next_sentence_label` argument is deprecated and will be removed in a future version, use `labels` instead.\",\n",
    "                FutureWarning,\n",
    "            )\n",
    "            labels = kwargs.pop(\"next_sentence_label\")\n",
    "\n",
    "        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n",
    "\n",
    "        outputs = self.bert(\n",
    "            input_ids,\n",
    "            attention_mask=attention_mask,\n",
    "            token_type_ids=token_type_ids,\n",
    "            position_ids=position_ids,\n",
    "            head_mask=head_mask,\n",
    "            inputs_embeds=inputs_embeds,\n",
    "            output_attentions=output_attentions,\n",
    "            output_hidden_states=output_hidden_states,\n",
    "            return_dict=return_dict,\n",
    "        )\n",
    "\n",
    "        pooled_output = outputs[1]\n",
    "\n",
    "        seq_relationship_scores = self.cls(pooled_output)\n",
    "\n",
    "        next_sentence_loss = None\n",
    "        if labels is not None:\n",
    "            loss_fct = CrossEntropyLoss()\n",
    "            next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))\n",
    "\n",
    "        if not return_dict:\n",
    "            output = (seq_relationship_scores,) + outputs[2:]\n",
    "            return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output\n",
    "\n",
    "        return NextSentencePredictorOutput(\n",
    "            loss=next_sentence_loss,\n",
    "            logits=seq_relationship_scores,\n",
    "            hidden_states=outputs.hidden_states,\n",
    "            attentions=outputs.attentions,\n",
    "        )\n",
    "from transformers import BertTokenizer, BertForNextSentencePrediction\n",
    "import torch\n",
    "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
    "model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')\n",
    "prompt = \"In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced.\"\n",
    "next_sentence = \"The sky is blue due to the shorter wavelength of blue light.\"\n",
    "encoding = tokenizer(prompt, next_sentence, return_tensors='pt')\n",
    "outputs = model(**encoding, labels=torch.LongTensor([1]))\n",
    "logits = outputs.logits\n",
    "assert logits[0, 0] < logits[0, 1] # next sentence was random"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "Downloading: 100%|██████████| 440M/440M [00:30<00:00, 14.5MB/s]\n",
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForNextSentencePrediction: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']\n",
      "- This IS expected if you are initializing BertForNextSentencePrediction from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForNextSentencePrediction from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "接下来介绍的是各种 Fine-tune 模型，基本都是分类任务：\n",
    "\n",
    "![Bert：finetune](./pictures/3-4-bert-ft.png) 图：Bert：finetune"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "*** \n",
    "### 3.2 BertForSequenceClassification\n",
    "这一模型用于句子分类（也可以是回归）任务，比如 GLUE benchmark 的各个任务。\n",
    "- 句子分类的输入为句子（对），输出为单个分类标签。\n",
    "\n",
    "结构上很简单，就是`BertModel`（有 pooling）过一个 dropout 后接一个线性层输出分类：\n",
    "```\n",
    "class BertForSequenceClassification(BertPreTrainedModel):\n",
    "    def __init__(self, config):\n",
    "        super().__init__(config)\n",
    "        self.num_labels = config.num_labels\n",
    "\n",
    "        self.bert = BertModel(config)\n",
    "        self.dropout = nn.Dropout(config.hidden_dropout_prob)\n",
    "        self.classifier = nn.Linear(config.hidden_size, config.num_labels)\n",
    "\n",
    "        self.init_weights()\n",
    "        # ...\n",
    "```\n",
    "\n",
    "在前向传播时，和上面预训练模型一样需要传入labels输入。\n",
    "\n",
    "- 如果初始化的num_labels=1，那么就默认为回归任务，使用 MSELoss；\n",
    "\n",
    "- 否则认为是分类任务。"
   ],
   "metadata": {}
  },
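  {
   "cell_type": "markdown",
   "source": [
    "The loss selection described above can be sketched as follows. `sequence_loss` is a hypothetical helper written for illustration and is not part of Transformers; it covers only the two cases named in the text, whereas the library code below additionally distinguishes multi-label classification via `config.problem_type`.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch.nn import CrossEntropyLoss, MSELoss\n",
    "\n",
    "def sequence_loss(logits, labels, num_labels):\n",
    "    # Hypothetical helper mirroring the rule described in the text.\n",
    "    if num_labels == 1:\n",
    "        # Regression: compare the single output value with a float target.\n",
    "        return MSELoss()(logits.view(-1), labels.view(-1))\n",
    "    # Classification: cross-entropy over num_labels classes.\n",
    "    return CrossEntropyLoss()(logits.view(-1, num_labels), labels.view(-1))\n",
    "\n",
    "torch.manual_seed(0)\n",
    "# Random stand-ins for the classifier output on a batch of 4 examples.\n",
    "reg_loss = sequence_loss(torch.randn(4, 1), torch.tensor([0.5, 1.0, 2.0, 3.5]), num_labels=1)\n",
    "cls_loss = sequence_loss(torch.randn(4, 3), torch.tensor([0, 2, 1, 2]), num_labels=3)\n",
    "print(reg_loss.item(), cls_loss.item())\n",
    "```"
   ],
   "metadata": {}
  },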
  {
   "cell_type": "code",
   "execution_count": 19,
   "source": [
    "@add_start_docstrings(\n",
    "    \"\"\"\n",
    "    Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled\n",
    "    output) e.g. for GLUE tasks.\n",
    "    \"\"\",\n",
    "    BERT_START_DOCSTRING,\n",
    ")\n",
    "class BertForSequenceClassification(BertPreTrainedModel):\n",
    "    def __init__(self, config):\n",
    "        super().__init__(config)\n",
    "        self.num_labels = config.num_labels\n",
    "        self.config = config\n",
    "\n",
    "        self.bert = BertModel(config)\n",
    "        classifier_dropout = (\n",
    "            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob\n",
    "        )\n",
    "        self.dropout = nn.Dropout(classifier_dropout)\n",
    "        self.classifier = nn.Linear(config.hidden_size, config.num_labels)\n",
    "\n",
    "        self.init_weights()\n",
    "\n",
    "    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format(\"batch_size, sequence_length\"))\n",
    "    @add_code_sample_docstrings(\n",
    "        tokenizer_class=_TOKENIZER_FOR_DOC,\n",
    "        checkpoint=_CHECKPOINT_FOR_DOC,\n",
    "        output_type=SequenceClassifierOutput,\n",
    "        config_class=_CONFIG_FOR_DOC,\n",
    "    )\n",
    "    def forward(\n",
    "        self,\n",
    "        input_ids=None,\n",
    "        attention_mask=None,\n",
    "        token_type_ids=None,\n",
    "        position_ids=None,\n",
    "        head_mask=None,\n",
    "        inputs_embeds=None,\n",
    "        labels=None,\n",
    "        output_attentions=None,\n",
    "        output_hidden_states=None,\n",
    "        return_dict=None,\n",
    "    ):\n",
    "        r\"\"\"\n",
    "        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):\n",
    "            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,\n",
    "            config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),\n",
    "            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n",
    "        \"\"\"\n",
    "        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n",
    "\n",
    "        outputs = self.bert(\n",
    "            input_ids,\n",
    "            attention_mask=attention_mask,\n",
    "            token_type_ids=token_type_ids,\n",
    "            position_ids=position_ids,\n",
    "            head_mask=head_mask,\n",
    "            inputs_embeds=inputs_embeds,\n",
    "            output_attentions=output_attentions,\n",
    "            output_hidden_states=output_hidden_states,\n",
    "            return_dict=return_dict,\n",
    "        )\n",
    "\n",
    "        pooled_output = outputs[1]\n",
    "\n",
    "        pooled_output = self.dropout(pooled_output)\n",
    "        logits = self.classifier(pooled_output)\n",
    "\n",
    "        loss = None\n",
    "        if labels is not None:\n",
    "            if self.config.problem_type is None:\n",
    "                if self.num_labels == 1:\n",
    "                    self.config.problem_type = \"regression\"\n",
    "                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):\n",
    "                    self.config.problem_type = \"single_label_classification\"\n",
    "                else:\n",
    "                    self.config.problem_type = \"multi_label_classification\"\n",
    "\n",
    "            if self.config.problem_type == \"regression\":\n",
    "                loss_fct = MSELoss()\n",
    "                if self.num_labels == 1:\n",
    "                    loss = loss_fct(logits.squeeze(), labels.squeeze())\n",
    "                else:\n",
    "                    loss = loss_fct(logits, labels)\n",
    "            elif self.config.problem_type == \"single_label_classification\":\n",
    "                loss_fct = CrossEntropyLoss()\n",
    "                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))\n",
    "            elif self.config.problem_type == \"multi_label_classification\":\n",
    "                loss_fct = BCEWithLogitsLoss()\n",
    "                loss = loss_fct(logits, labels)\n",
    "        if not return_dict:\n",
    "            output = (logits,) + outputs[2:]\n",
    "            return ((loss,) + output) if loss is not None else output\n",
    "\n",
    "        return SequenceClassifierOutput(\n",
    "            loss=loss,\n",
    "            logits=logits,\n",
    "            hidden_states=outputs.hidden_states,\n",
    "            attentions=outputs.attentions,\n",
    "        )"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "source": [
    "from transformers.models.bert.tokenization_bert import BertTokenizer\n",
    "from transformers.models.bert.modeling_bert import BertForSequenceClassification\n",
    "tokenizer = BertTokenizer.from_pretrained(\"bert-base-cased-finetuned-mrpc\")\n",
    "model = BertForSequenceClassification.from_pretrained(\"bert-base-cased-finetuned-mrpc\")\n",
    "\n",
    "classes = [\"not paraphrase\", \"is paraphrase\"]\n",
    "\n",
    "sequence_0 = \"The company HuggingFace is based in New York City\"\n",
    "sequence_1 = \"Apples are especially bad for your health\"\n",
    "sequence_2 = \"HuggingFace's headquarters are situated in Manhattan\"\n",
    "\n",
    "# The tokekenizer will automatically add any model specific separators (i.e. <CLS> and <SEP>) and tokens to the sequence, as well as compute the attention masks.\n",
    "paraphrase = tokenizer(sequence_0, sequence_2, return_tensors=\"pt\")\n",
    "not_paraphrase = tokenizer(sequence_0, sequence_1, return_tensors=\"pt\")\n",
    "\n",
    "paraphrase_classification_logits = model(**paraphrase).logits\n",
    "not_paraphrase_classification_logits = model(**not_paraphrase).logits\n",
    "\n",
    "paraphrase_results = torch.softmax(paraphrase_classification_logits, dim=1).tolist()[0]\n",
    "not_paraphrase_results = torch.softmax(not_paraphrase_classification_logits, dim=1).tolist()[0]\n",
    "\n",
    "# Should be paraphrase\n",
    "for i in range(len(classes)):\n",
    "    print(f\"{classes[i]}: {int(round(paraphrase_results[i] * 100))}%\")\n",
    "\n",
    "# Should not be paraphrase\n",
    "for i in range(len(classes)):\n",
    "    print(f\"{classes[i]}: {int(round(not_paraphrase_results[i] * 100))}%\")"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "not paraphrase: 10%\n",
      "is paraphrase: 90%\n",
      "not paraphrase: 94%\n",
      "is paraphrase: 6%\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "*** \n",
    "### 3.3 BertForMultipleChoice\n",
    "\n",
    "这一模型用于多项选择，如 RocStories/SWAG 任务。\n",
    "- 多项选择任务的输入为一组分次输入的句子，输出为选择某一句子的单个标签。\n",
    "结构上与句子分类相似，只不过线性层输出维度为 1，即每次需要将每个样本的多个句子的输出拼接起来作为每个样本的预测分数。\n",
    "- 实际上，具体操作时是把每个 batch 的多个句子一同放入的，所以一次处理的输入为[batch_size, num_choices]数量的句子，因此相同 batch 大小时，比句子分类等任务需要更多的显存，在训练时需要小心。"
   ],
   "metadata": {}
  },
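  {
   "cell_type": "markdown",
   "source": [
    "The reshaping described above can be sketched as follows. This is a minimal illustration rather than the library's code: the toy sizes are hypothetical, and a random tensor stands in for the pooled `[CLS]` output that `BertModel` would produce.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "# Hypothetical toy sizes, chosen only for illustration.\n",
    "batch_size, num_choices, seq_len, hidden_size = 2, 4, 8, 16\n",
    "\n",
    "# Stand-in for tokenized inputs of shape (batch_size, num_choices, seq_len).\n",
    "input_ids = torch.randint(0, 100, (batch_size, num_choices, seq_len))\n",
    "\n",
    "# Flatten the choice dimension so the encoder sees an ordinary 2-D batch.\n",
    "flat_input_ids = input_ids.view(-1, input_ids.size(-1))\n",
    "\n",
    "# Stand-in for the pooled encoder output: one vector per flattened sequence.\n",
    "pooled_output = torch.randn(flat_input_ids.size(0), hidden_size)\n",
    "\n",
    "# One score per (sample, choice) pair, then regroup the scores by sample.\n",
    "classifier = nn.Linear(hidden_size, 1)\n",
    "logits = classifier(pooled_output)\n",
    "reshaped_logits = logits.view(-1, num_choices)\n",
    "\n",
    "labels = torch.tensor([1, 3])  # index of the correct choice for each sample\n",
    "loss = nn.CrossEntropyLoss()(reshaped_logits, labels)\n",
    "\n",
    "print(flat_input_ids.shape)   # torch.Size([8, 8])\n",
    "print(reshaped_logits.shape)  # torch.Size([2, 4])\n",
    "```\n",
    "\n",
    "The flattened batch holds `batch_size * num_choices` sequences, which is where the extra memory cost comes from."
   ],
   "metadata": {}
  },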
  {
   "cell_type": "markdown",
   "source": [
    "*** \n",
    "### 3.4 BertForTokenClassification\n",
    "这一模型用于序列标注（词分类），如 NER 任务。\n",
    "- 序列标注任务的输入为单个句子文本，输出为每个 token 对应的类别标签。\n",
    "由于需要用到每个 token对应的输出而不只是某几个，所以这里的BertModel不用加入 pooling 层；\n",
    "- 同时，这里将`_keys_to_ignore_on_load_unexpected`这一个类参数设置为`[r\"pooler\"]`，也就是在加载模型时对于出现不需要的权重不发生报错。"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "source": [
    "class BertForMultipleChoice(BertPreTrainedModel):\n",
    "    def __init__(self, config):\n",
    "        super().__init__(config)\n",
    "\n",
    "        self.bert = BertModel(config)\n",
    "        self.dropout = nn.Dropout(config.hidden_dropout_prob)\n",
    "        self.classifier = nn.Linear(config.hidden_size, 1)\n",
    "\n",
    "        self.init_weights()\n",
    "\n",
    "    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format(\"batch_size, num_choices, sequence_length\"))\n",
    "    @add_code_sample_docstrings(\n",
    "        tokenizer_class=_TOKENIZER_FOR_DOC,\n",
    "        checkpoint=_CHECKPOINT_FOR_DOC,\n",
    "        output_type=MultipleChoiceModelOutput,\n",
    "        config_class=_CONFIG_FOR_DOC,\n",
    "    )\n",
    "    def forward(\n",
    "        self,\n",
    "        input_ids=None,\n",
    "        attention_mask=None,\n",
    "        token_type_ids=None,\n",
    "        position_ids=None,\n",
    "        head_mask=None,\n",
    "        inputs_embeds=None,\n",
    "        labels=None,\n",
    "        output_attentions=None,\n",
    "        output_hidden_states=None,\n",
    "        return_dict=None,\n",
    "    ):\n",
    "        r\"\"\"\n",
    "        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):\n",
    "            Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,\n",
    "            num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See\n",
    "            :obj:`input_ids` above)\n",
    "        \"\"\"\n",
    "        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n",
    "        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]\n",
    "\n",
    "        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None\n",
    "        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None\n",
    "        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None\n",
    "        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None\n",
    "        inputs_embeds = (\n",
    "            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))\n",
    "            if inputs_embeds is not None\n",
    "            else None\n",
    "        )\n",
    "\n",
    "        outputs = self.bert(\n",
    "            input_ids,\n",
    "            attention_mask=attention_mask,\n",
    "            token_type_ids=token_type_ids,\n",
    "            position_ids=position_ids,\n",
    "            head_mask=head_mask,\n",
    "            inputs_embeds=inputs_embeds,\n",
    "            output_attentions=output_attentions,\n",
    "            output_hidden_states=output_hidden_states,\n",
    "            return_dict=return_dict,\n",
    "        )\n",
    "\n",
    "        pooled_output = outputs[1]\n",
    "\n",
    "        pooled_output = self.dropout(pooled_output)\n",
    "        logits = self.classifier(pooled_output)\n",
    "        reshaped_logits = logits.view(-1, num_choices)\n",
    "\n",
    "        loss = None\n",
    "        if labels is not None:\n",
    "            loss_fct = CrossEntropyLoss()\n",
    "            loss = loss_fct(reshaped_logits, labels)\n",
    "\n",
    "        if not return_dict:\n",
    "            output = (reshaped_logits,) + outputs[2:]\n",
    "            return ((loss,) + output) if loss is not None else output\n",
    "\n",
    "        return MultipleChoiceModelOutput(\n",
    "            loss=loss,\n",
    "            logits=reshaped_logits,\n",
    "            hidden_states=outputs.hidden_states,\n",
    "            attentions=outputs.attentions,\n",
    "        )\n"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "source": [
    "@add_start_docstrings(\n",
    "    \"\"\"\n",
    "    Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for\n",
    "    Named-Entity-Recognition (NER) tasks.\n",
    "    \"\"\",\n",
    "    BERT_START_DOCSTRING,\n",
    ")\n",
    "class BertForTokenClassification(BertPreTrainedModel):\n",
    "\n",
    "    _keys_to_ignore_on_load_unexpected = [r\"pooler\"]\n",
    "\n",
    "    def __init__(self, config):\n",
    "        super().__init__(config)\n",
    "        self.num_labels = config.num_labels\n",
    "\n",
    "        self.bert = BertModel(config, add_pooling_layer=False)\n",
    "        classifier_dropout = (\n",
    "            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob\n",
    "        )\n",
    "        self.dropout = nn.Dropout(classifier_dropout)\n",
    "        self.classifier = nn.Linear(config.hidden_size, config.num_labels)\n",
    "\n",
    "        self.init_weights()\n",
    "\n",
    "    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format(\"batch_size, sequence_length\"))\n",
    "    @add_code_sample_docstrings(\n",
    "        tokenizer_class=_TOKENIZER_FOR_DOC,\n",
    "        checkpoint=_CHECKPOINT_FOR_DOC,\n",
    "        output_type=TokenClassifierOutput,\n",
    "        config_class=_CONFIG_FOR_DOC,\n",
    "    )\n",
    "    def forward(\n",
    "        self,\n",
    "        input_ids=None,\n",
    "        attention_mask=None,\n",
    "        token_type_ids=None,\n",
    "        position_ids=None,\n",
    "        head_mask=None,\n",
    "        inputs_embeds=None,\n",
    "        labels=None,\n",
    "        output_attentions=None,\n",
    "        output_hidden_states=None,\n",
    "        return_dict=None,\n",
    "    ):\n",
    "        r\"\"\"\n",
    "        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):\n",
    "            Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -\n",
    "            1]``.\n",
    "        \"\"\"\n",
    "        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n",
    "\n",
    "        outputs = self.bert(\n",
    "            input_ids,\n",
    "            attention_mask=attention_mask,\n",
    "            token_type_ids=token_type_ids,\n",
    "            position_ids=position_ids,\n",
    "            head_mask=head_mask,\n",
    "            inputs_embeds=inputs_embeds,\n",
    "            output_attentions=output_attentions,\n",
    "            output_hidden_states=output_hidden_states,\n",
    "            return_dict=return_dict,\n",
    "        )\n",
    "\n",
    "        sequence_output = outputs[0]\n",
    "\n",
    "        sequence_output = self.dropout(sequence_output)\n",
    "        logits = self.classifier(sequence_output)\n",
    "\n",
    "        loss = None\n",
    "        if labels is not None:\n",
    "            loss_fct = CrossEntropyLoss()\n",
    "            # Only keep active parts of the loss\n",
    "            if attention_mask is not None:\n",
    "                active_loss = attention_mask.view(-1) == 1\n",
    "                active_logits = logits.view(-1, self.num_labels)\n",
    "                active_labels = torch.where(\n",
    "                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)\n",
    "                )\n",
    "                loss = loss_fct(active_logits, active_labels)\n",
    "            else:\n",
    "                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))\n",
    "\n",
    "        if not return_dict:\n",
    "            output = (logits,) + outputs[2:]\n",
    "            return ((loss,) + output) if loss is not None else output\n",
    "\n",
    "        return TokenClassifierOutput(\n",
    "            loss=loss,\n",
    "            logits=logits,\n",
    "            hidden_states=outputs.hidden_states,\n",
    "            attentions=outputs.attentions,\n",
    "        )\n"
   ],
   "outputs": [],
   "metadata": {}
  },
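  {
   "cell_type": "markdown",
   "source": [
    "The loss computation in `BertForTokenClassification.forward` above replaces the labels of padded positions with the loss function's `ignore_index`. A minimal sketch of that masking on toy tensors (the shapes and values are hypothetical, and random logits stand in for real model output):\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "num_labels = 3\n",
    "# Two sequences of length 4; the second one ends with two padding tokens.\n",
    "logits = torch.randn(2, 4, num_labels)\n",
    "labels = torch.tensor([[0, 1, 2, 0], [1, 2, 0, 0]])\n",
    "attention_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])\n",
    "\n",
    "loss_fct = nn.CrossEntropyLoss()  # ignore_index defaults to -100\n",
    "\n",
    "# Replace the labels at padded positions with ignore_index.\n",
    "active_loss = attention_mask.view(-1) == 1\n",
    "active_labels = torch.where(\n",
    "    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)\n",
    ")\n",
    "masked_loss = loss_fct(logits.view(-1, num_labels), active_labels)\n",
    "\n",
    "# Reference: the same loss computed only over the six real tokens.\n",
    "reference = loss_fct(logits.view(-1, num_labels)[active_loss], labels.view(-1)[active_loss])\n",
    "\n",
    "print(active_labels.tolist())                  # [0, 1, 2, 0, 1, 2, -100, -100]\n",
    "print(torch.allclose(masked_loss, reference))  # True\n",
    "```\n",
    "\n",
    "Padding tokens therefore never contribute to the averaged loss, whatever label they were given."
   ],
   "metadata": {}
  },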
  {
   "cell_type": "code",
   "execution_count": 26,
   "source": [
    "from transformers import BertForTokenClassification, BertTokenizer\n",
    "import torch\n",
    "\n",
    "model = BertForTokenClassification.from_pretrained(\"dbmdz/bert-large-cased-finetuned-conll03-english\")\n",
    "tokenizer = BertTokenizer.from_pretrained(\"bert-base-cased\")\n",
    "\n",
    "label_list = [\n",
    "\"O\",       # Outside of a named entity\n",
    "\"B-MISC\",  # Beginning of a miscellaneous entity right after another miscellaneous entity\n",
    "\"I-MISC\",  # Miscellaneous entity\n",
    "\"B-PER\",   # Beginning of a person's name right after another person's name\n",
    "\"I-PER\",   # Person's name\n",
    "\"B-ORG\",   # Beginning of an organisation right after another organisation\n",
    "\"I-ORG\",   # Organisation\n",
    "\"B-LOC\",   # Beginning of a location right after another location\n",
    "\"I-LOC\"    # Location\n",
    "]\n",
    "\n",
    "sequence = \"Hugging Face Inc. is a company based in New York City. Its headquarters are in DUMBO, therefore very close to the Manhattan Bridge.\"\n",
    "\n",
    "# Bit of a hack to get the tokens with the special tokens\n",
    "tokens = tokenizer.tokenize(tokenizer.decode(tokenizer.encode(sequence)))\n",
    "inputs = tokenizer.encode(sequence, return_tensors=\"pt\")\n",
    "\n",
    "outputs = model(inputs).logits\n",
    "predictions = torch.argmax(outputs, dim=2)"
   ],
   "outputs": [
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "source": [
    "for token, prediction in zip(tokens, predictions[0].numpy()):\n",
    "    print((token, model.config.id2label[prediction]))"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "('[CLS]', 'O')\n",
      "('Hu', 'I-ORG')\n",
      "('##gging', 'I-ORG')\n",
      "('Face', 'I-ORG')\n",
      "('Inc', 'I-ORG')\n",
      "('.', 'O')\n",
      "('is', 'O')\n",
      "('a', 'O')\n",
      "('company', 'O')\n",
      "('based', 'O')\n",
      "('in', 'O')\n",
      "('New', 'I-LOC')\n",
      "('York', 'I-LOC')\n",
      "('City', 'I-LOC')\n",
      "('.', 'O')\n",
      "('Its', 'O')\n",
      "('headquarters', 'O')\n",
      "('are', 'O')\n",
      "('in', 'O')\n",
      "('D', 'I-LOC')\n",
      "('##UM', 'I-LOC')\n",
      "('##BO', 'I-LOC')\n",
      "(',', 'O')\n",
      "('therefore', 'O')\n",
      "('very', 'O')\n",
      "('close', 'O')\n",
      "('to', 'O')\n",
      "('the', 'O')\n",
      "('Manhattan', 'I-LOC')\n",
      "('Bridge', 'I-LOC')\n",
      "('.', 'O')\n",
      "('[SEP]', 'O')\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "*** \n",
    "### 3.5 BertForQuestionAnswering\n",
    "这一模型用于解决问答任务，例如 SQuAD 任务。\n",
    "- 问答任务的输入为问题 +（对于 BERT 只能是一个）回答组成的句子对，输出为起始位置和结束位置用于标出回答中的具体文本。\n",
    "这里需要两个输出，即对起始位置的预测和对结束位置的预测，两个输出的长度都和句子长度一样，从其中挑出最大的预测值对应的下标作为预测的位置。\n",
    "- 对超出句子长度的非法 label，会将其压缩（torch.clamp_）到合理范围。\n",
    "\n",
    "以上就是关于 BERT 源码的介绍，下面介绍一些关于 BERT 模型实用的训练细节。"
   ],
   "metadata": {}
  },
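  {
   "cell_type": "markdown",
   "source": [
    "A minimal sketch of how out-of-range positions are handled, using toy tensors (the sizes and positions are hypothetical, and random logits stand in for real model output). It mirrors the clamp-then-ignore pattern used in `BertForQuestionAnswering.forward`:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "seq_len = 5\n",
    "start_logits = torch.randn(3, seq_len)\n",
    "\n",
    "# The third gold position (9) lies outside the 5-token input, e.g. after truncation.\n",
    "start_positions = torch.tensor([1, 4, 9])\n",
    "\n",
    "ignored_index = start_logits.size(1)  # 5, one past the last valid index\n",
    "clamped = start_positions.clamp(0, ignored_index)\n",
    "\n",
    "loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)\n",
    "loss = loss_fct(start_logits, clamped)\n",
    "\n",
    "# Reference: only the first two samples contribute to the averaged loss.\n",
    "reference = nn.CrossEntropyLoss()(start_logits[:2], start_positions[:2])\n",
    "\n",
    "print(clamped.tolist())                 # [1, 4, 5]\n",
    "print(torch.allclose(loss, reference))  # True\n",
    "```\n",
    "\n",
    "The same treatment applies to the end positions; the final loss is the mean of the start and end losses."
   ],
   "metadata": {}
  },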
  {
   "cell_type": "code",
   "execution_count": 23,
   "source": [
    "@add_start_docstrings(\n",
    "    \"\"\"\n",
    "    Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear\n",
    "    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).\n",
    "    \"\"\",\n",
    "    BERT_START_DOCSTRING,\n",
    ")\n",
    "class BertForQuestionAnswering(BertPreTrainedModel):\n",
    "\n",
    "    _keys_to_ignore_on_load_unexpected = [r\"pooler\"]\n",
    "\n",
    "    def __init__(self, config):\n",
    "        super().__init__(config)\n",
    "        self.num_labels = config.num_labels\n",
    "\n",
    "        self.bert = BertModel(config, add_pooling_layer=False)\n",
    "        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)\n",
    "\n",
    "        self.init_weights()\n",
    "\n",
    "    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format(\"batch_size, sequence_length\"))\n",
    "    @add_code_sample_docstrings(\n",
    "        tokenizer_class=_TOKENIZER_FOR_DOC,\n",
    "        checkpoint=_CHECKPOINT_FOR_DOC,\n",
    "        output_type=QuestionAnsweringModelOutput,\n",
    "        config_class=_CONFIG_FOR_DOC,\n",
    "    )\n",
    "    def forward(\n",
    "        self,\n",
    "        input_ids=None,\n",
    "        attention_mask=None,\n",
    "        token_type_ids=None,\n",
    "        position_ids=None,\n",
    "        head_mask=None,\n",
    "        inputs_embeds=None,\n",
    "        start_positions=None,\n",
    "        end_positions=None,\n",
    "        output_attentions=None,\n",
    "        output_hidden_states=None,\n",
    "        return_dict=None,\n",
    "    ):\n",
    "        r\"\"\"\n",
    "        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):\n",
    "            Labels for position (index) of the start of the labelled span for computing the token classification loss.\n",
    "            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the\n",
    "            sequence are not taken into account for computing the loss.\n",
    "        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):\n",
    "            Labels for position (index) of the end of the labelled span for computing the token classification loss.\n",
    "            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the\n",
    "            sequence are not taken into account for computing the loss.\n",
    "        \"\"\"\n",
    "        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n",
    "\n",
    "        outputs = self.bert(\n",
    "            input_ids,\n",
    "            attention_mask=attention_mask,\n",
    "            token_type_ids=token_type_ids,\n",
    "            position_ids=position_ids,\n",
    "            head_mask=head_mask,\n",
    "            inputs_embeds=inputs_embeds,\n",
    "            output_attentions=output_attentions,\n",
    "            output_hidden_states=output_hidden_states,\n",
    "            return_dict=return_dict,\n",
    "        )\n",
    "\n",
    "        sequence_output = outputs[0]\n",
    "\n",
    "        logits = self.qa_outputs(sequence_output)\n",
    "        start_logits, end_logits = logits.split(1, dim=-1)\n",
    "        start_logits = start_logits.squeeze(-1).contiguous()\n",
    "        end_logits = end_logits.squeeze(-1).contiguous()\n",
    "\n",
    "        total_loss = None\n",
    "        if start_positions is not None and end_positions is not None:\n",
    "            # If we are on multi-GPU, split add a dimension\n",
    "            if len(start_positions.size()) > 1:\n",
    "                start_positions = start_positions.squeeze(-1)\n",
    "            if len(end_positions.size()) > 1:\n",
    "                end_positions = end_positions.squeeze(-1)\n",
    "            # sometimes the start/end positions are outside our model inputs, we ignore these terms\n",
    "            ignored_index = start_logits.size(1)\n",
    "            start_positions = start_positions.clamp(0, ignored_index)\n",
    "            end_positions = end_positions.clamp(0, ignored_index)\n",
    "\n",
    "            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)\n",
    "            start_loss = loss_fct(start_logits, start_positions)\n",
    "            end_loss = loss_fct(end_logits, end_positions)\n",
    "            total_loss = (start_loss + end_loss) / 2\n",
    "\n",
    "        if not return_dict:\n",
    "            output = (start_logits, end_logits) + outputs[2:]\n",
    "            return ((total_loss,) + output) if total_loss is not None else output\n",
    "\n",
    "        return QuestionAnsweringModelOutput(\n",
    "            loss=total_loss,\n",
    "            start_logits=start_logits,\n",
    "            end_logits=end_logits,\n",
    "            hidden_states=outputs.hidden_states,\n",
    "            attentions=outputs.attentions,\n",
    "        )\n"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "source": [
    "from transformers import AutoTokenizer, AutoModelForQuestionAnswering\n",
    "import torch\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"bert-large-uncased-whole-word-masking-finetuned-squad\")\n",
    "model = AutoModelForQuestionAnswering.from_pretrained(\"bert-large-uncased-whole-word-masking-finetuned-squad\")\n",
    "\n",
    "text = \"🤗 Transformers (formerly known as pytorch-transformers and pytorch-pretrained-bert) provides general-purpose architectures (BERT, GPT-2, RoBERTa, XLM, DistilBert, XLNet…) for Natural Language Understanding (NLU) and Natural Language Generation (NLG) with over 32+ pretrained models in 100+ languages and deep interoperability between TensorFlow 2.0 and PyTorch.\"\n",
    "\n",
    "questions = [\n",
    "\"How many pretrained models are available in 🤗 Transformers?\",\n",
    "\"What does 🤗 Transformers provide?\",\n",
    "\"🤗 Transformers provides interoperability between which frameworks?\",\n",
    "]\n",
    "\n",
    "for question in questions:\n",
    "    inputs = tokenizer(question, text, add_special_tokens=True, return_tensors=\"pt\")\n",
    "    input_ids = inputs[\"input_ids\"].tolist()[0]\n",
    "    outputs = model(**inputs)\n",
    "    answer_start_scores = outputs.start_logits\n",
    "    answer_end_scores = outputs.end_logits\n",
    "    answer_start = torch.argmax(\n",
    "        answer_start_scores\n",
    "    )  # Get the most likely beginning of answer with the argmax of the score\n",
    "    answer_end = torch.argmax(answer_end_scores) + 1  # Get the most likely end of answer with the argmax of the score\n",
    "    answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end]))\n",
    "    print(f\"Question: {question}\")\n",
    "    print(f\"Answer: {answer}\")"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Question: How many pretrained models are available in 🤗 Transformers?\n",
      "Answer: over 32 +\n",
      "Question: What does 🤗 Transformers provide?\n",
      "Answer: general - purpose architectures\n",
      "Question: 🤗 Transformers provides interoperability between which frameworks?\n",
      "Answer: tensorflow 2. 0 and pytorch\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "*** \n",
    "## BERT训练和优化\n",
    "### 4.1 Pre-Training\n",
    "预训练阶段，除了众所周知的 15%、80% mask 比例，有一个值得注意的地方就是参数共享。\n",
    "不止 BERT，所有 huggingface 实现的 PLM 的 word embedding 和 masked language model 的预测权重在初始化过程中都是共享的：\n",
    "```\n",
    "class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):\n",
    "    # ...\n",
    "    def tie_weights(self):\n",
    "        \"\"\"\n",
    "        Tie the weights between the input embeddings and the output embeddings.\n",
    "\n",
    "        If the :obj:`torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning\n",
    "        the weights instead.\n",
    "        \"\"\"\n",
    "        output_embeddings = self.get_output_embeddings()\n",
    "        if output_embeddings is not None and self.config.tie_word_embeddings:\n",
    "            self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())\n",
    "\n",
    "        if self.config.is_encoder_decoder and self.config.tie_encoder_decoder:\n",
    "            if hasattr(self, self.base_model_prefix):\n",
    "                self = getattr(self, self.base_model_prefix)\n",
    "            self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix)\n",
    "    # ...\n",
    "```\n",
    "\n",
    "至于为什么，应该是因为 word_embedding 和 prediction 权重太大了，以 bert-base 为例，其尺寸为(30522, 768)，降低训练难度。\n"
   ],
   "metadata": {}
  },
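  {
   "cell_type": "markdown",
   "source": [
    "A minimal sketch of what tying means in plain PyTorch. It assumes that, outside of the torchscript case, tying simply makes both modules hold the same parameter object; the two modules below are stand-ins, not the actual BERT classes.\n",
    "\n",
    "```python\n",
    "from torch import nn\n",
    "\n",
    "vocab_size, hidden_size = 30522, 768  # the bert-base sizes quoted above\n",
    "\n",
    "word_embeddings = nn.Embedding(vocab_size, hidden_size)\n",
    "decoder = nn.Linear(hidden_size, vocab_size, bias=False)\n",
    "\n",
    "# Tie the weights: both modules now point at the same Parameter object.\n",
    "decoder.weight = word_embeddings.weight\n",
    "\n",
    "print(decoder.weight is word_embeddings.weight)  # True\n",
    "print(word_embeddings.weight.numel())            # 23440896\n",
    "```\n",
    "\n",
    "Because `nn.Linear` stores its weight as `(out_features, in_features)`, the embedding matrix can be reused without a transpose, and any gradient update to one module is seen by the other."
   ],
   "metadata": {}
  },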
  {
   "cell_type": "markdown",
   "source": [
    "***\n",
    "### 4.2 Fine-Tuning\n",
    "微调也就是下游任务阶段，也有两个值得注意的地方。\n",
    "#### 4.2.1 AdamW\n",
    "首先介绍一下 BERT 的优化器：AdamW（AdamWeightDecayOptimizer）。\n",
    "\n",
    "这一优化器来自 ICLR 2017 的 Best Paper：《Fixing Weight Decay Regularization in Adam》中提出的一种用于修复 Adam 的权重衰减错误的新方法。论文指出，L2 正则化和权重衰减在大部分情况下并不等价，只在 SGD 优化的情况下是等价的；而大多数框架中对于 Adam+L2 正则使用的是权重衰减的方式，两者不能混为一谈。\n",
    "\n",
    "AdamW 是在 Adam+L2 正则化的基础上进行改进的算法，与一般的 Adam+L2 的区别如下：\n",
    "\n",
    "![图：AdamW](./pictures/3-5-adamw.png) 图：AdamW\n",
    "\n",
    "关于 AdamW 的分析可以参考：\n",
    "\n",
    "- AdamW and Super-convergence is now the fastest way to train neural nets [1]\n",
    "- paperplanet：都 9102 年了，别再用 Adam + L2 regularization了 [2]\n",
    "\n",
    "通常，我们会选择模型的 weight 部分参与 decay 过程，而另一部分（包括 LayerNorm 的 weight）不参与（代码最初来源应该是 Huggingface 的示例）\n",
    "补充：关于这么做的理由，我暂时没有找到合理的解答。\n",
    "\n",
    "```\n",
    "# model: a Bert-based-model object\n",
    "    # learning_rate: default 2e-5 for text classification\n",
    "    param_optimizer = list(model.named_parameters())\n",
    "    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']\n",
    "    optimizer_grouped_parameters = [\n",
    "        {'params': [p for n, p in param_optimizer if not any(\n",
    "            nd in n for nd in no_decay)], 'weight_decay': 0.01},\n",
    "        {'params': [p for n, p in param_optimizer if any(\n",
    "            nd in n for nd in no_decay)], 'weight_decay': 0.0}\n",
    "    ]\n",
    "    optimizer = AdamW(optimizer_grouped_parameters,\n",
    "                      lr=learning_rate)\n",
    "    # ...\n",
    "```\n",
    "\n",
    "#### 4.2.2 Warmup\n",
    "\n",
    "BERT 的训练中另一个特点在于 Warmup，其含义为：\n",
    "\n",
    "在训练初期使用较小的学习率（从 0 开始），在一定步数（比如 1000 步）内逐渐提高到正常大小（比如上面的 2e-5），避免模型过早进入局部最优而过拟合；\n",
    "- 在训练后期再慢慢将学习率降低到 0，避免后期训练还出现较大的参数变化。\n",
    "- 在 Huggingface 的实现中，可以使用多种 warmup 策略：\n",
    "```\n",
    "TYPE_TO_SCHEDULER_FUNCTION = {\n",
    "    SchedulerType.LINEAR: get_linear_schedule_with_warmup,\n",
    "    SchedulerType.COSINE: get_cosine_schedule_with_warmup,\n",
    "    SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,\n",
    "    SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,\n",
    "    SchedulerType.CONSTANT: get_constant_schedule,\n",
    "    SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,\n",
    "}\n",
    "```\n",
    "具体而言：\n",
    "- CONSTANT：保持固定学习率不变；\n",
    "- CONSTANT_WITH_WARMUP：在每一个 step 中线性调整学习率；\n",
    "- LINEAR：上文提到的两段式调整；\n",
    "- COSINE：和两段式调整类似，只不过采用的是三角函数式的曲线调整；\n",
    "- COSINE_WITH_RESTARTS：训练中将上面 COSINE 的调整重复 n 次；\n",
    "- POLYNOMIAL：按指数曲线进行两段式调整。\n",
    "具体使用参考transformers/optimization.py：\n",
    "最常用的还是get_linear_scheduler_with_warmup即线性两段式调整学习率的方案。\n",
    "\n",
    "```\n",
    "def get_scheduler(\n",
    "    name: Union[str, SchedulerType],\n",
    "    optimizer: Optimizer,\n",
    "    num_warmup_steps: Optional[int] = None,\n",
    "    num_training_steps: Optional[int] = None,\n",
    "): ...\n",
    "\n",
    "```\n",
    "\n",
    "以上即为关于 transformers 库（4.4.2 版本）中 BERT 应用的相关代码的具体实现分析，欢迎与读者共同交流探讨。\n",
    "\n",
    "## 致谢\n",
    "本文主要由浙江大学李泺秋撰写，本项目同学负责整理和汇总。"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [],
   "outputs": [],
   "metadata": {}
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "3bfce0b4c492a35815b5705a19fe374a7eea0baaa08b34d90450caf1fe9ce20b"
  },
  "kernelspec": {
   "display_name": "Python 3.8.10 64-bit ('venv': virtualenv)",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}